import itertools
import json
from collections.abc import Iterable
from typing import Any

import jsonschema
import regex
from django.conf import settings
from django.utils import timezone

from sysreptor.pentests import cvss
from sysreptor.pentests.models import (
    Comment,
    CommentStatus,
    PentestFinding,
    PentestProject,
    ProjectType,
    ReportSection,
    ReviewStatus,
)
from sysreptor.pentests.rendering.error_messages import (
    ErrorMessage,
    MessageLevel,
    MessageLocationInfo,
    MessageLocationType,
)
from sysreptor.utils.fielddefinition.types import BaseField, CvssVersion, FieldDataType, FieldDefinition
from sysreptor.utils.fielddefinition.utils import iterate_fields
from sysreptor.utils.utils import find_all_indices


def text_snippet(text: str):
    """Return a single-line preview of ``text`` for use in message details.

    ``None`` or empty input gives ``''``. Leading/trailing whitespace is stripped, newlines are
    replaced by spaces, and text longer than 100 characters is cut to 100 characters plus ``'...'``.
    """
    max_len = 100
    flattened = (text or '').strip().replace('\n', ' ')
    if len(flattened) <= max_len:
        return flattened
    return flattened[:max_len] + '...'


class ReportCheck:
    """Base class for checks that inspect a pentest project and report problems as ErrorMessages.

    Subclasses override ``check_project``, ``check_section`` and/or ``check_finding``. Each hook
    returns an iterable of ErrorMessage (empty by default), and ``check`` combines the hooks for the
    project, every section and every finding.
    """

    def location_info(self, obj, path=None):
        """Build a MessageLocationInfo pointing at ``obj``, optionally narrowed to ``path``.

        Raises ValueError for objects that are not a PentestProject, ReportSection,
        PentestFinding or ProjectType.
        """
        if isinstance(obj, PentestProject):
            loc_type, loc_id, loc_name = MessageLocationType.PROJECT, obj.id, obj.name
        elif isinstance(obj, ReportSection):
            loc_type, loc_id, loc_name = MessageLocationType.SECTION, obj.section_id, obj.section_label
        elif isinstance(obj, PentestFinding):
            loc_type, loc_id, loc_name = MessageLocationType.FINDING, obj.finding_id, obj.data.get('title')
        elif isinstance(obj, ProjectType):
            loc_type, loc_id, loc_name = MessageLocationType.DESIGN, obj.id, obj.name
        else:
            raise ValueError('Unsupported MessageLocationInfo')
        return MessageLocationInfo(type=loc_type, id=loc_id, name=loc_name).for_path(path)

    def check(self, project: PentestProject) -> Iterable[ErrorMessage]:
        """Run every hook and return all of their messages as one chained iterable.

        The hooks themselves are called right away (project first, then sections, then findings).
        The iterables they return are only consumed when the result is iterated.
        """
        results = [self.check_project(project)]
        results.extend(self.check_section(section) for section in project.sections.all())
        results.extend(self.check_finding(finding) for finding in project.findings.all())
        return itertools.chain.from_iterable(results)

    def check_project(self, project: PentestProject) -> Iterable[ErrorMessage]:
        """Project-level hook; reports nothing unless overridden."""
        return []

    def check_section(self, section: ReportSection) -> Iterable[ErrorMessage]:
        """Per-section hook; reports nothing unless overridden."""
        return []

    def check_finding(self, finding: PentestFinding) -> Iterable[ErrorMessage]:
        """Per-finding hook; reports nothing unless overridden."""
        return []


class FieldCheck(ReportCheck):
    """ReportCheck that validates each field of every section and finding individually.

    Subclasses override ``check_field``. This class walks the field data via ``iterate_fields``
    and hands every (value, definition) pair to it, along with a location pointing at the field path.
    """

    def check_field(self, value: Any, definition: BaseField, location: MessageLocationInfo) -> Iterable[ErrorMessage]:
        """Check one field value; the default implementation reports nothing."""
        return []

    def check_fields(self, data: dict, definition: FieldDefinition, location: MessageLocationInfo) -> Iterable[ErrorMessage]:
        """Yield the messages produced by ``check_field`` for every field found in ``data``."""
        for field_path, field_value, field_definition in iterate_fields(value=data, definition=definition):
            field_location = location.for_path(field_path)
            yield from self.check_field(value=field_value, definition=field_definition, location=field_location)

    def check_section(self, section) -> Iterable[ErrorMessage]:
        """Validate all fields of a report section."""
        return self.check_fields(
            section.data,
            section.field_definition,
            self.location_info(section),
        )

    def check_finding(self, finding) -> Iterable[ErrorMessage]:
        """Validate all fields of a finding."""
        return self.check_fields(
            finding.data,
            finding.field_definition,
            self.location_info(finding),
        )


class TodoCheck(FieldCheck):
    """Warn about string fields that still contain TODO markers."""

    def check_field(self, value, definition, location) -> Iterable[ErrorMessage]:
        if not isinstance(value, str):
            return

        # Case-sensitive marker spellings. Snippets are grouped by marker in this order,
        # then by position of each occurrence as returned by find_all_indices.
        markers = ['TODO', 'ToDo', 'TO-DO', 'To-Do']
        snippets = [
            text_snippet(value[index:])
            for marker in markers
            for index in find_all_indices(value, marker)
        ]
        if not snippets:
            return

        yield ErrorMessage(
            level=MessageLevel.WARNING,
            location=location,
            message='Unresolved TODO',
            details='\n'.join(snippets),
        )


class EmptyFieldsCheck(FieldCheck):
    """Warn about required fields whose value is None, an empty string or an empty list."""

    def check_field(self, value, definition, location) -> Iterable[ErrorMessage]:
        # Field definitions without a ``required`` attribute are treated as optional.
        if not getattr(definition, 'required', False):
            return
        if value is None or value == '' or value == []:
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location,
                message='Empty field',
            )


class StatusCheck(ReportCheck):
    """Warn about sections and findings whose review status is not FINISHED."""

    def check_status(self, obj: ReportSection | PentestFinding):
        """Yield a warning if ``obj`` does not have review status FINISHED."""
        if obj.status == ReviewStatus.FINISHED:
            return
        yield ErrorMessage(
            level=MessageLevel.WARNING,
            location=self.location_info(obj=obj),
            message=f'Status is not "{ReviewStatus.FINISHED}"',
            details=f'Status is "{obj.status}", not status "{ReviewStatus.FINISHED}"',
        )

    def check(self, project: PentestProject) -> Iterable[ErrorMessage]:
        # When every section and every finding is still "in-progress", the project presumably does
        # not use the review workflow at all. The check is then skipped, which avoids a flood of
        # warnings nobody would act on. Findings are only queried if all sections are in progress.
        all_in_progress = (
            all(section.status == ReviewStatus.IN_PROGRESS for section in project.sections.all())
            and all(finding.status == ReviewStatus.IN_PROGRESS for finding in project.findings.all())
        )
        if all_in_progress:
            return []
        return super().check(project)

    def check_section(self, section: ReportSection) -> Iterable[ErrorMessage]:
        return self.check_status(section)

    def check_finding(self, finding: PentestFinding) -> Iterable[ErrorMessage]:
        return self.check_status(finding)


class CvssFieldCheck(FieldCheck):
    """Warn about CVSS fields that hold an invalid vector or a vector of the wrong CVSS version."""

    def check_field(self, value, definition, location):
        # Non-CVSS fields are ignored; empty values and the explicit "n/a" are accepted as-is.
        if not definition.type == FieldDataType.CVSS or value in [None, '', 'n/a']:
            return

        if not cvss.is_cvss(value):
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location,
                message='Invalid CVSS vector',
                details=f'"{value}" is not a valid CVSS vector. Enter "n/a" when no CVSS vector is applicable.',
            )
            return

        # The version check is only meaningful for syntactically valid vectors.
        required_version = definition.cvss_version
        if required_version != CvssVersion.ANY and not value.startswith(required_version.value):
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location,
                message='Invalid CVSS version',
                details=f'"{value}" does not match the required CVSS version {definition.cvss_version.value}',
            )


class RegexPatternCheck(FieldCheck):
    """Warn when a string field's value does not match the ``pattern`` of its field definition.

    All regex evaluations of one instance share a single time budget, initialised from
    ``settings.REGEX_VALIDATION_TIMEOUT``. Each evaluation's elapsed time is subtracted from it,
    whatever the outcome. Once the budget is used up, further fields are reported as timed out
    without running their regex.
    """

    def __init__(self) -> None:
        # Total timeout for all regex checks in a project: the remaining budget (a timedelta,
        # since .total_seconds() is called on it), decremented after every regex evaluation.
        self.timeout = settings.REGEX_VALIDATION_TIMEOUT

    def _timeout_message(self, location) -> ErrorMessage:
        """Build the error reported when the shared regex time budget is exceeded."""
        return ErrorMessage(
            level=MessageLevel.ERROR,
            location=location,
            message='Regex timeout',
            details='Regex timeout exceeded while validating field',
        )

    def check_field(self, value, definition, location):
        if not (definition.type == FieldDataType.STRING and definition.pattern and value):
            return

        remaining_seconds = self.timeout.total_seconds()
        if remaining_seconds <= 0:
            # Budget already spent by earlier fields: do not call regex.match with a zero/negative
            # timeout, which is not a meaningful limit.
            # NOTE(review): the regex module appears to treat negative timeouts as "no timeout".
            yield self._timeout_message(location)
            return

        message = None
        start_time = timezone.now()
        try:
            res = regex.match(pattern=definition.pattern, string=value, timeout=remaining_seconds)
        except TimeoutError:
            message = self._timeout_message(location)
        except regex.error as ex:
            message = ErrorMessage(
                level=MessageLevel.ERROR,
                location=location,
                message='Invalid regex pattern',
                details=f'Failed to compile regex pattern /{ex.pattern}/: {ex.msg}',
            )
        else:
            if not res:
                message = ErrorMessage(
                    level=MessageLevel.WARNING,
                    location=location,
                    message='Invalid format',
                    details=f'Value "{value}" does not match pattern /{definition.pattern}/',
                )
        finally:
            # Charge the elapsed time on every outcome, including timeouts and errors. Otherwise a
            # timed-out field would consume nothing and each later field would get the full
            # remaining budget again, defeating the per-project total limit.
            # This runs before the yield below, so consumer time is not counted.
            self.timeout -= timezone.now() - start_time

        if message is not None:
            yield message


class NumberFieldCheck(FieldCheck):
    """Warn about number fields whose value lies outside the definition's minimum/maximum."""

    def check_field(self, value, definition, location):
        if not definition.type == FieldDataType.NUMBER or value is None:
            return

        # Bounds are optional; a None bound is not enforced. Both bounds are checked independently.
        lower_bound = definition.minimum
        if lower_bound is not None and value < lower_bound:
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location,
                message='Number out of range',
                details=f'Number {value} is less than the minimum value {definition.minimum}',
            )

        upper_bound = definition.maximum
        if upper_bound is not None and value > upper_bound:
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location,
                message='Number out of range',
                details=f'Number {value} is greater than the maximum value {definition.maximum}',
            )


class JsonFieldCheck(FieldCheck):
    """Warn about JSON fields that cannot be parsed or do not validate against the field's schema."""

    def check_field(self, value, definition, location):
        if not (definition.type == FieldDataType.JSON and value):
            return

        try:
            parsed = json.loads(value)
            if definition.schema:
                jsonschema.validate(instance=parsed, schema=definition.schema)
        except json.JSONDecodeError as ex:
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location,
                message='Invalid JSON',
                details=f'Failed to parse JSON: {ex}',
            )
        except jsonschema.ValidationError as ex:
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location,
                message='JSON data does not match schema',
                details=str(ex),
            )
        except Exception as ex:
            # Catch-all (jsonschema.SchemaError is one such Exception): any other failure is
            # reported as a schema problem.
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location,
                message='Invalid JSON schema',
                details=str(ex),
            )


class UnresolvedCommentCheck(ReportCheck):
    """Warn about comments on sections and findings that are not resolved yet."""

    def check_comments(self, comments: list[Comment], location: MessageLocationInfo) -> Iterable[ErrorMessage]:
        """Yield one warning per comment whose status is not RESOLVED.

        The comment's ``path`` has a leading ``'data.'`` removed before it is used as the
        location path.
        """
        open_comments = (c for c in comments if c.status != CommentStatus.RESOLVED)
        for open_comment in open_comments:
            field_path = open_comment.path.removeprefix('data.')
            yield ErrorMessage(
                level=MessageLevel.WARNING,
                location=location.for_path(field_path),
                message='Unresolved comment',
                details=text_snippet(open_comment.text),
            )

    def check_section(self, section: ReportSection) -> Iterable[ErrorMessage]:
        section_location = self.location_info(section)
        return self.check_comments(section.comments.all(), section_location) if False else \
            self.check_comments(section.comments.all(), self.location_info(section))

    def check_finding(self, finding: PentestFinding) -> Iterable[ErrorMessage]:
        return self.check_comments(finding.comments.all(), self.location_info(finding))


def run_checks(project) -> Iterable[ErrorMessage]:
    """Run all built-in report checks on ``project`` and return their messages as a list.

    Checkers run independently and in a fixed order. If a checker raises, the messages it produced
    up to that point are kept and an ERROR "Error while checking data" message is appended, so one
    failing checker does not abort the others.
    """
    def perform_check(checker) -> list[ErrorMessage]:
        messages = []
        try:
            # checker.check() returns lazy iterables (generators / itertools.chain), so most of the
            # actual checking, and most exceptions, happen while iterating. Consume the result inside
            # the try; otherwise those exceptions escape and crash run_checks. An explicit loop keeps
            # the messages collected before a failure.
            for message in checker.check(project):
                messages.append(message)
        except Exception as ex:
            messages.append(ErrorMessage(
                level=MessageLevel.ERROR,
                location=MessageLocationInfo(type=MessageLocationType.OTHER),
                message='Error while checking data',
                details=str(ex),
            ))
        return messages

    checkers = [
        TodoCheck(),
        EmptyFieldsCheck(),
        CvssFieldCheck(),
        RegexPatternCheck(),
        NumberFieldCheck(),
        JsonFieldCheck(),
        StatusCheck(),
        UnresolvedCommentCheck(),
    ]
    return list(itertools.chain.from_iterable(map(perform_check, checkers)))
